package com.xiaofan.scala

import org.apache.flink.api.common.state.{MapState, MapStateDescriptor}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeHint, TypeInformation}
import org.apache.flink.api.java.typeutils.ListTypeInfo
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.datastream.BroadcastStream
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction
import org.apache.flink.streaming.api.scala.{StreamExecutionEnvironment, _}
import org.apache.flink.util.Collector

import scala.collection.JavaConversions._

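/**
 * Note: type-system pitfall for the per-key item buffer state.
 * Declaring it as
 * private val mapStateDesc: MapStateDescriptor[String, List[Item]] =
 * new MapStateDescriptor[String, List[Item]]("items", BasicTypeInfo.STRING_TYPE_INFO, null)
 * does not work: MapStateDescriptor rejects a null value TypeInformation, and a Scala
 * immutable List cannot be appended to in place. The working form uses a mutable
 * java.util.List[Item] typed with new ListTypeInfo[Item](classOf[Item]).
 */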


/** Placeholder colour attribute; it has no fields, so every instance is equal to every other. */
case class Color()

/** Placeholder shape attribute; it has no fields, so every instance is equal to every other. */
case class Shape()

/**
 * An element of the keyed input stream.
 *
 * @param color attribute the item stream is partitioned by
 * @param shape compared against a rule's `first` and `second` shapes
 */
case class Item(
  color: Color,
  shape: Shape
)

/**
 * A rule delivered through the broadcast stream.
 *
 * @param name   key under which the rule is stored in broadcast state
 * @param first  shape of items that get buffered for this rule
 * @param second shape that, when seen, is matched against the buffered items
 */
case class Rule(
  name: String,
  first: Shape,
  second: Shape
)

object KeyedBroadcastProcessFunctionD_0003 {

  /**
   * Demo job: connects a keyed stream of [[Item]]s with a broadcast stream of [[Rule]]s.
   *
   * Every rule is stored in broadcast state under its name. For each incoming item, and for
   * each rule currently in broadcast state:
   *  - if the item's shape equals `rule.second`, one "MATCH" line is emitted per item buffered
   *    for that rule (under the current color key);
   *  - if the item's shape equals `rule.first`, the item is appended to that buffer.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Explicit `.asScala` decorators instead of the deprecated implicit JavaConversions.
    import scala.collection.JavaConverters._

    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    val itemStream: DataStream[Item] = env.fromElements(Item(Color(), Shape()), Item(Color(), Shape()))
    val ruleStream: DataStream[Rule] = env.fromElements(Rule("", Shape(), Shape()), Rule("", Shape(), Shape()))

    // Keyed by color, so the per-rule item buffers below are scoped to one color.
    val colorPartitionedStream: KeyedStream[Item, Color] = itemStream.keyBy(_.color)

    // Broadcast state: rule name -> rule.
    val ruleStateDescriptor: MapStateDescriptor[String, Rule] = new MapStateDescriptor[String, Rule](
      "RulesBroadcastState", BasicTypeInfo.STRING_TYPE_INFO, TypeInformation.of(new TypeHint[Rule] {}))

    val ruleBroadcastStream: BroadcastStream[Rule] = ruleStream.broadcast(ruleStateDescriptor)

    val out: DataStream[String] = colorPartitionedStream
      .connect(ruleBroadcastStream)
      .process(new KeyedBroadcastProcessFunction[Color, Item, Rule, String]() {

        // Keyed state: rule name -> items buffered while waiting for that rule's second shape.
        // The value must be a mutable java.util.List with an explicit ListTypeInfo: a null value
        // TypeInformation is rejected by MapStateDescriptor, and a Scala immutable List cannot
        // be appended to in place.
        private val mapStateDesc: MapStateDescriptor[String, java.util.List[Item]] =
          new MapStateDescriptor[String, java.util.List[Item]](
            "items", BasicTypeInfo.STRING_TYPE_INFO, new ListTypeInfo[Item](classOf[Item]))

        override def processElement(
            value: Item,
            ctx: KeyedBroadcastProcessFunction[Color, Item, Rule, String]#ReadOnlyContext,
            out: Collector[String]): Unit = {
          val state: MapState[String, java.util.List[Item]] = getRuntimeContext.getMapState(mapStateDesc)
          val shape: Shape = value.shape

          for (entry <- ctx.getBroadcastState(ruleStateDescriptor).immutableEntries().asScala) {
            val ruleName: String = entry.getKey
            val rule: Rule = entry.getValue

            // MapState.get yields null when nothing is buffered yet for this rule.
            val stored: java.util.List[Item] =
              Option(state.get(ruleName)).getOrElse(new java.util.ArrayList[Item]())

            if (shape == rule.second && !stored.isEmpty) {
              stored.asScala.foreach(i => out.collect("MATCH: " + i + " - " + value))
            }

            // Deliberately no `else`: when rule.first == rule.second the item is both
            // matched against the buffer and added to it.
            if (shape == rule.first) stored.add(value)

            // Drop empty buffers; otherwise write back so non-heap state backends see the update.
            if (stored.isEmpty) state.remove(ruleName)
            else state.put(ruleName, stored)
          }
        }

        override def processBroadcastElement(
            value: Rule,
            ctx: KeyedBroadcastProcessFunction[Color, Item, Rule, String]#Context,
            out: Collector[String]): Unit = {
          ctx.getBroadcastState(ruleStateDescriptor).put(value.name, value)
        }
      })

    out.print()
    env.execute("KeyedBroadcastProcessFunctionD_0003")
  }
}
